-- -*- lua -*-
--
-- SipHash is a family of hash functions, parameterized by the number of
-- rounds that run per 8-byte input block and the number of rounds that
-- run at the end.  SipHash was designed to be good for short inputs and
-- to resist "hash flooding" attacks, where users attack a hash table
-- with many inputs that map to the same area of the hash table.
--
-- We provide implementations of SipHash with a normal scalar DynASM
-- backend as well as parallel SSE and AVX2 backends.
--
-- Because we use SipHash as a hash function for fixed-sized inputs, we
-- can simplify processing of the tail word.  This simplification is
-- enabled unless a true value for "as_specified" is passed.
--
-- This implementation was based on the reference implementation.  See
-- https://131002.net/siphash/ for more details.

module(..., package.seeall)

local bit  = require("bit")
local dasm = require("dasm")
local ffi  = require("ffi")
local S    = require("syscall")
local lib  = require("core.lib")

-- Set to true to print a dump of every generated machine-code buffer
-- (see finish() below).  Note that this local shadows Lua's standard
-- "debug" library for the rest of the file.
local debug = false

-- Detect which vector backends the CPU can run by searching the flag
-- list in /proc/cpuinfo.  string.match returns the matched text (a
-- true value) or nil.
local cpuinfo = lib.readfile("/proc/cpuinfo", "*a")
assert(cpuinfo, "failed to read /proc/cpuinfo for hardware check")
local have_avx2 = cpuinfo:match("avx2")
-- The two-lane SSE backend below is not pure SSE2: it emits pinsrb,
-- pinsrd, pinsrq and pextrd, which are SSE4.1 instructions.  Require
-- the "sse4_1" flag as well, so that CPUs lacking it use the scalar
-- fallback instead of faulting on an illegal instruction.  The name
-- is kept for the benefit of existing users of this local.
local have_sse2 = cpuinfo:match("sse2") and cpuinfo:match("sse4_1")

-- DynASM directives: target x86-64 and collect the action list for
-- all "|" lines in this file under the name "actions".
|.arch x64
|.actionlist actions

-- Module-level list that holds references to generated machine code
-- and to data the generated code points at (presumably so that none
-- of it is garbage-collected while the hash functions are in use).
__anchor = {}

-- Build the code accumulated in the DynASM state `Dst`, record the
-- resulting buffer in __anchor, and return it cast to the C function
-- pointer type `prototype`.  When the file-level debug flag is set, a
-- dump labelled with `name` is printed first.
local function finish (name, prototype, Dst)
   local code, code_size = Dst:build()
   __anchor[#__anchor + 1] = code
   if debug then
      print("mcode dump: "..name)
      dasm.dump(code, code_size)
   end
   local fn = ffi.cast(prototype, code)
   return fn
end

-- Return a fresh 16-byte SipHash key obtained from lib.random_bytes.
function random_sip_hash_key()
   local key_length = 16
   return lib.random_bytes(key_length)
end

-- Build a 16-byte key from a numeric seed: the seed is stored as a
-- uint32_t in the first four bytes and the remaining bytes stay zero.
-- Raises an error if `seed` cannot be converted to a number.
function sip_hash_key_from_seed(seed)
   local seed_value = assert(tonumber(seed))
   local key = ffi.new('uint8_t[16]')
   local words = ffi.cast('uint32_t*', key)
   words[0] = seed_value
   return key
end

-- Compute the four 64-bit SipHash state words v0..v3 for a key.
--
-- key: a 16-byte buffer (anything ffi.cast can view as uint64_t*), or
--      nil/false to use a freshly generated random key.
-- Returns a new uint64_t[4] holding the constants XORed with the key.
local function load_initial_state(key)
   -- Initial state constants.
   local state = ffi.new('uint64_t[4]',
                         { 0x736f6d6570736575ULL,
                           0x646f72616e646f6dULL,
                           0x6c7967656e657261ULL,
                           0x7465646279746573ULL })

   -- Keep the key buffer in a local for as long as we read through the
   -- cast pointer: a pointer produced by ffi.cast does not keep the
   -- underlying object alive, and the 64-bit bit.bxor calls below
   -- allocate, so a key generated here could otherwise be collected
   -- before it has been read.
   local key_bytes = key or random_sip_hash_key()
   local key_words = ffi.cast('uint64_t*', key_bytes)

   -- Mix key into state constants: k0 into v0 and v2, k1 into v1 and v3.
   state[0] = bit.bxor(state[0], key_words[0])
   state[1] = bit.bxor(state[1], key_words[1])
   state[2] = bit.bxor(state[2], key_words[0])
   state[3] = bit.bxor(state[3], key_words[1])

   return state
end

-- Scalar x86-64 backend for the SipHash implementation.
--
-- Returns a table of code-emitting operations.  Each operation takes
-- small integer "value indexes" (0-7) that are mapped onto
-- general-purpose registers through scratchregs, and emits the
-- matching instructions through the "|" DynASM lines (which use the
-- Dst state created in asm:assemble).  The generated function reads
-- its input through the pointer in rdi and returns a uint32_t in eax.
local function X86_64()
   local asm = {}
   local Dst
   -- Register names in hardware encoding order, so that the number of
   -- a register is its position in this list minus one.
   local allregs = {"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi",
                    "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"}
   local regnums = {}
   for i,reg in ipairs(allregs) do regnums[reg] = i-1 end
   -- Map value indexes from the program to registers.  Don't allocate
   -- to rdi; it's the input data.
   local scratchregs = {"r8", "r9", "r10", "r11", "rax", "rcx", "rdx", "rsi"}
   -- Translate a zero-based value index into a register number;
   -- asserts if the index has no scratch register.
   local function regnum(n) return assert(regnums[assert(scratchregs[n+1])]) end

   -- Load the key-mixed initial state into values 0-3 as 64-bit
   -- immediates.
   function asm.init(key)
      local initial_state = load_initial_state(key)
      table.insert(__anchor, initial_state)
      | mov64 Rq(regnum(0)), initial_state[0]
      | mov64 Rq(regnum(1)), initial_state[1]
      | mov64 Rq(regnum(2)), initial_state[2]
      | mov64 Rq(regnum(3)), initial_state[3]
   end
   -- dst = dst + other (64-bit).
   function asm.add(dst, other)
      | add Rq(regnum(dst)), Rq(regnum(other))
   end
   -- reg = reg << bits.
   function asm.shl(reg, bits)
      | shl Rq(regnum(reg)), bits
   end
   -- dst = dst | other.
   function asm.ior(dst, other)
      | or Rq(regnum(dst)), Rq(regnum(other))
   end
   -- dst = dst ~ other (exclusive or).
   function asm.xor(dst, other)
      | xor Rq(regnum(dst)), Rq(regnum(other))
   end
   -- Rotate reg left by bits.
   function asm.rol(reg, bits)
      | rol Rq(regnum(reg)), bits
   end
   -- Copy an incoming argument register (arg 1 is rdi, arg 2 is rsi)
   -- into value dst; emits nothing if dst already is that register.
   function asm.copy_argument(dst, arg)
      -- For use by make_sip_hash_u64.
      local arg = assert(({"rdi", "rsi"})[arg])
      if regnum(dst) ~= regnums[arg] then
         | mov Rq(regnum(dst)), Rq(regnums[arg])
      end
   end
   -- The load_*_and_advance operations read the next 8/4/2/1 input
   -- bytes at [rdi] into value dst, zero-extended to 64 bits, and
   -- then move rdi past them.
   function asm.load_u64_and_advance(dst)
      | mov Rq(regnum(dst)), [rdi]
      | add rdi, 8
   end
   function asm.load_u32_and_advance(dst)
      -- Automatically zero-extending.
      | mov Rd(regnum(dst)), [rdi]
      | add rdi, 4
   end
   function asm.load_u16_and_advance(dst)
      | movzx Rq(regnum(dst)), word [rdi]
      | add rdi, 2
   end
   function asm.load_u8_and_advance(dst)
      | movzx Rq(regnum(dst)), byte [rdi]
      | add rdi, 1
   end
   -- Set value dst to the 8-bit constant imm: clear the whole
   -- register, then write the low byte if it is non-zero.
   function asm.load_imm8(dst, imm)
      | xor Rq(regnum(dst)), Rq(regnum(dst))
      if imm ~= 0 then
         | mov Rb(regnum(dst)), imm
      end
   end
   -- Return the upper 32 bits of value reg, shifted left by one within
   -- 32 bits (so the low bit of the result is always zero).
   function asm.ret(reg)
      if regnum(reg) ~= regnums["rax"] then
         | mov rax, Rq(regnum(reg))
      end
      | shr rax, 32
      -- The ctable needs the hash to not be 0xFFFFFFFF.  If we did the
      -- lshift there instead of here, we'd have to compensate for that
      -- weird thing where LuaJIT's bitops can produce negative numbers.
      | shl eax, 1
      | ret
   end
   -- Build the generated code as a uint32_t (*)(void *) function.
   -- Kept as a replaceable field; make_u64_hash overrides it.
   function asm.finish(name, Dst)
      return finish(name.."_x1",
                    ffi.typeof("uint32_t (*)(void *)"),
                    Dst)
   end
   -- Create a fresh DynASM state, let gen emit the program through
   -- this backend's operations, and return the built function.
   function asm:assemble(name, gen)
      Dst = dasm.new(actions)
      gen(self)
      return self.finish(name, Dst)
   end
   return asm
end

-- Parallel SSE backend for the SipHash implementation that can hash two
-- inputs at once.
--
-- stride is the distance in bytes between the two inputs: the first is
-- read at [rdi] and the second at [rdi+stride].  Value indexes map
-- directly onto xmm registers, with one input per 64-bit lane; xmm5
-- is used as a temporary by rol.  The generated function takes the
-- input pointer in rdi and an output uint32_t pointer in rsi.
--
-- NOTE(review): pinsrb, pinsrd, pinsrq and pextrd below are SSE4.1
-- instructions and movddup is SSE3, so this backend needs more than
-- baseline SSE2.
local function SSE(stride)
   local asm = {}
   local Dst
   -- Load each key-mixed state word into both lanes of xmm0-xmm3.
   function asm.init(key)
      local initial_state = load_initial_state(key)
      table.insert(__anchor, initial_state)
      | mov rax, initial_state
      | movddup xmm0, [rax]
      | movddup xmm1, [rax+8]
      | movddup xmm2, [rax+16]
      | movddup xmm3, [rax+24]
   end
   -- Per-lane 64-bit add.
   function asm.add(dst, other)
      | paddq xmm(dst), xmm(other)
   end
   -- Per-lane 64-bit shift left.
   function asm.shl(reg, bits)
      | psllq xmm(reg), bits
   end
   function asm.ior(dst, other)
      | por xmm(dst), xmm(other)
   end
   function asm.xor(dst, other)
      | pxor xmm(dst), xmm(other)
   end
   -- Per-lane rotate left, composed from a left shift of a copy (in
   -- xmm5), a right shift by (64 - bits), and an OR of the two.
   function asm.rol(reg, bits)
      | movupd xmm5, xmm(reg)
      | psllq xmm5, bits
      | psrlq xmm(reg), (64 - bits)
      | por xmm(reg), xmm5
   end
   -- The load_*_and_advance operations read the next 8/4/2/1 bytes of
   -- both inputs into the low and high lanes of dst, then move rdi
   -- past them.
   function asm.load_u64_and_advance(dst)
      | movlpd xmm(dst), qword [rdi]
      | movhpd xmm(dst), qword [rdi+stride]
      | add rdi, 8
   end
   function asm.load_u32_and_advance(dst)
      -- movss from memory clears the rest of the register; the second
      -- input goes into dword 2, the low half of the high lane.
      | movss xmm(dst), dword [rdi]
      | pinsrd xmm(dst), dword [rdi + stride], 2
      | add rdi, 4
   end
   function asm.load_u16_and_advance(dst)
      -- Clear dst first, then insert at word 0 and word 4 (the lowest
      -- word of each lane).
      | pxor xmm(dst), xmm(dst)
      | pinsrw xmm(dst), word [rdi], 0
      | pinsrw xmm(dst), word [rdi + stride], 4
      | add rdi, 2
   end
   function asm.load_u8_and_advance(dst)
      -- Clear dst first, then insert at byte 0 and byte 8 (the lowest
      -- byte of each lane).
      | pxor xmm(dst), xmm(dst)
      | pinsrb xmm(dst), byte [rdi], 0
      | pinsrb xmm(dst), byte [rdi + stride], 8
      | add rdi, 1
   end
   -- Set both lanes of dst to the 8-bit constant imm, going through
   -- rax.
   function asm.load_imm8(dst, imm)
      | xor rax, rax
      if imm ~= 0 then
         | mov al, imm
      end
      | pinsrq xmm(dst), rax, 0
      | pinsrq xmm(dst), rax, 1
   end
   function asm.ret(reg)
      -- Extract high 31 bits from each 64-bit value to uint32_t[2] in
      -- rsi.  Every dword is shifted left by one first; dwords 1 and 3
      -- are the upper halves of the two lanes.
      | pslld xmm(reg), 1
      | pextrd dword [rsi], xmm(reg), 1
      | pextrd dword [rsi+4], xmm(reg), 3
      | ret
   end
   -- Create a fresh DynASM state, let gen emit the program through
   -- this backend's operations, and return the built function.
   function asm:assemble(name, gen)
      Dst = dasm.new(actions)
      gen(self)
      return finish(name.."_x2",
                    ffi.typeof("void (*)(uint8_t *, uint32_t *)"),
                    Dst)
   end
   return asm
end

-- Parallel AVX2 backend for the SipHash implementation that can hash
-- four inputs at once.
--
-- stride is the distance in bytes between consecutive inputs, which
-- are read at [rdi], [rdi+stride], [rdi+stride*2] and [rdi+stride*3].
-- Value indexes map directly onto ymm registers, with one input per
-- 64-bit lane; ymm5/xmm5 is used as a temporary by rol, by the loads
-- and by ret.  The generated function takes the input pointer in rdi
-- and an output uint32_t pointer in rsi.
local function AVX2(stride)
   local asm = {}
   local Dst
   -- Broadcast each key-mixed state word to all four lanes of
   -- ymm0-ymm3.
   function asm.init(key)
      local initial_state = load_initial_state(key)
      table.insert(__anchor, initial_state)
      | vzeroupper
      | mov rax, initial_state
      | vbroadcastsd ymm0, qword [rax]
      | vbroadcastsd ymm1, qword [rax+8]
      | vbroadcastsd ymm2, qword [rax+16]
      | vbroadcastsd ymm3, qword [rax+24]
   end
   -- Per-lane 64-bit add.
   function asm.add(dst, other)
      | vpaddq ymm(dst), ymm(dst), ymm(other)
   end
   -- Per-lane 64-bit shift left.
   function asm.shl(reg, bits)
      | vpsllq ymm(reg), ymm(reg), bits
   end
   function asm.ior(dst, other)
      | vpor ymm(dst), ymm(dst), ymm(other)
   end
   function asm.xor(dst, other)
      | vpxor ymm(dst), ymm(dst), ymm(other)
   end
   -- Per-lane rotate left: the left-shifted copy goes to ymm5, reg is
   -- shifted right by (64 - bits), and the two are ORed together.
   function asm.rol(reg, bits)
      | vpsllq ymm5, ymm(reg), bits
      | vpsrlq ymm(reg), ymm(reg), (64 - bits)
      | vpor ymm(reg), ymm(reg), ymm5
   end
   -- The load_*_and_advance operations gather the next 8/4/2/1 bytes
   -- of all four inputs into the four lanes of dst, then move rdi past
   -- them.  Inputs 0 and 1 are inserted into xmm(dst), inputs 2 and 3
   -- into xmm5, and xmm5 then becomes the upper 128 bits of ymm(dst).
   -- dst must be above 5, so it is neither a state register (0-3) nor
   -- the xmm5 temporary.
   function asm.load_u64_and_advance(dst)
      -- TODO: Use a parallel load here.
      assert(dst > 5)
      | vpinsrq xmm(dst), xmm(dst), qword [rdi], 0
      | vpinsrq xmm(dst), xmm(dst), qword [rdi+stride*1], 1
      | vpinsrq xmm5, xmm5, qword [rdi+stride*2], 0
      | vpinsrq xmm5, xmm5, qword [rdi+stride*3], 1
      | vinsertf128 ymm(dst), ymm(dst), xmm5, 1
      | add rdi, 8
   end
   function asm.load_u32_and_advance(dst)
      assert(dst > 5)
      -- The 32-bit moves zero-extend into rax and rdx, which are then
      -- inserted as whole 64-bit lanes.
      | mov eax, [rdi]
      | mov edx, [rdi+stride]
      | vpinsrq xmm(dst), xmm(dst), rax, 0
      | vpinsrq xmm(dst), xmm(dst), rdx, 1
      | mov eax, [rdi+stride*2]
      | mov edx, [rdi+stride*3]
      | vpinsrq xmm5, xmm5, rax, 0
      | vpinsrq xmm5, xmm5, rdx, 1
      | vinsertf128 ymm(dst), ymm(dst), xmm5, 1
      | add rdi, 4
   end
   function asm.load_u16_and_advance(dst)
      assert(dst > 5)
      | movzx rax, word [rdi]
      | movzx rdx, word [rdi+stride]
      | vpinsrq xmm(dst), xmm(dst), rax, 0
      | vpinsrq xmm(dst), xmm(dst), rdx, 1
      | movzx rax, word [rdi+stride*2]
      | movzx rdx, word [rdi+stride*3]
      | vpinsrq xmm5, xmm5, rax, 0
      | vpinsrq xmm5, xmm5, rdx, 1
      | vinsertf128 ymm(dst), ymm(dst), xmm5, 1
      | add rdi, 2
   end
   function asm.load_u8_and_advance(dst)
      assert(dst > 5)
      | movzx rax, byte [rdi]
      | movzx rdx, byte [rdi+stride]
      | vpinsrq xmm(dst), xmm(dst), rax, 0
      | vpinsrq xmm(dst), xmm(dst), rdx, 1
      | movzx rax, byte [rdi+stride*2]
      | movzx rdx, byte [rdi+stride*3]
      | vpinsrq xmm5, xmm5, rax, 0
      | vpinsrq xmm5, xmm5, rdx, 1
      | vinsertf128 ymm(dst), ymm(dst), xmm5, 1
      | add rdi, 1
   end
   -- Set all four lanes of dst to the 8-bit constant imm: fill the
   -- low 128 bits from rax, then copy them into the upper 128 bits.
   function asm.load_imm8(dst, imm)
      | xor rax, rax
      if imm ~= 0 then
         | mov al, imm
      end
      | vpinsrq xmm(dst), xmm(dst), rax, 0
      | vpinsrq xmm(dst), xmm(dst), rax, 1
      | vinsertf128 ymm(dst), ymm(dst), xmm(dst), 1
   end
   -- Pack four 2-bit selectors (each 0-3) into one immediate byte,
   -- with a in the lowest two bits.  Not called anywhere in this
   -- function as it stands.
   local function imm8_control(a, b, c, d)
      assert(bit.band(bit.bor(a,b,c,d), 3) == bit.bor(a,b,c,d))
      return a + b*4 + c*4*4 + d*4*4*4
   end
   function asm.ret(reg)
      -- Extract high 32 bits from each 64-bit value.  For ridiculous
      -- reasons, the way we have to do this is to read in a set of
      -- indexes into the 32-bit words from memory, then permute, then
      -- write out.
      --
      -- The indexes 1, 3, 5, 7 select the upper dword of each lane
      -- into the low 128 bits.  Each of those dwords is then shifted
      -- left by one and the four results are stored at [rsi].
      local control = ffi.new('uint32_t[8]', 1, 3, 5, 7, 0, 0, 0, 0)
      table.insert(__anchor, control)
      | mov rax, control
      | vmovdqu ymm5, [rax]
      | vpermd ymm(reg), ymm5, ymm(reg)
      | vpslld xmm(reg), xmm(reg), 1
      | vmovdqu oword [rsi], xmm(reg)
      | vzeroupper
      | ret
   end
   -- Create a fresh DynASM state, let gen emit the program through
   -- this backend's operations, and return the built function.
   function asm:assemble(name, gen)
      Dst = dasm.new(actions)
      gen(self)
      return finish(name.."_x4",
                    ffi.typeof("void (*)(uint8_t *, uint32_t *)"),
                    Dst)
   end
   return asm
end

-- Portable backend for the SipHash implementation to use as a
-- reference.
--
-- Rather than emitting machine code, every operation is carried out
-- immediately on an array of eight uint64_t "registers".  The function
-- returned by assemble re-runs the generator on each call, reading
-- from the pointer it is given and returning the hash as a number.
local function Simulator()
   local regs            -- uint64_t[8], allocated by init
   local cursor, result  -- current read position; value set by ret
   local asm = {}

   -- Read one value through `ctype` (a pointer type name) at the
   -- cursor into register dst, then step the cursor by nbytes.
   local function fetch(dst, ctype, nbytes)
      regs[dst] = ffi.cast(ctype, cursor)[0]
      cursor = cursor + nbytes
   end

   asm.init = function(key)
      local seeded = load_initial_state(key)
      regs = ffi.new('uint64_t[8]')
      for i=0,3 do regs[i] = seeded[i] end
   end
   asm.add = function(dst, other)
      regs[dst] = regs[dst] + regs[other]
   end
   asm.shl = function(reg, bits)
      regs[reg] = bit.lshift(regs[reg], bits)
   end
   asm.ior = function(dst, other)
      regs[dst] = bit.bor(regs[dst], regs[other])
   end
   asm.xor = function(dst, other)
      regs[dst] = bit.bxor(regs[dst], regs[other])
   end
   asm.rol = function(reg, bits)
      regs[reg] = bit.rol(regs[reg], bits)
   end
   asm.load_u64_and_advance = function(dst) fetch(dst, 'uint64_t*', 8) end
   asm.load_u32_and_advance = function(dst) fetch(dst, 'uint32_t*', 4) end
   asm.load_u16_and_advance = function(dst) fetch(dst, 'uint16_t*', 2) end
   asm.load_u8_and_advance = function(dst) fetch(dst, 'uint8_t*', 1) end
   asm.load_imm8 = function(dst, imm)
      regs[dst] = imm
   end
   asm.ret = function(reg)
      -- Upper 32 bits, shifted left by one; the round trip through a
      -- uint32_t turns a negative bitop result back into an unsigned
      -- value.
      local high = tonumber(bit.rshift(regs[reg], 32))
      result = ffi.new('uint32_t[1]', bit.lshift(high, 1))[0]
   end
   asm.assemble = function(self, name, gen)
      return function(ptr)
         cursor, result = ffi.cast('uint8_t*', ptr), nil
         gen(self)
         cursor = nil
         return result
      end
   end
   return asm
end

-- Options accepted by make_sip_hash (validated with lib.parse):
--   size:         input size in bytes (required).
--   stride:       passed to the backend constructor; defaults to size.
--   key:          16-byte key; a false value means a random key.
--   c:            SipRounds run per processed input word (default 2).
--   d:            SipRounds run during finalization (default 4).
--   as_specified: if true, build the tail word from the size byte and
--                 the remaining bytes instead of using the fixed-size
--                 simplification.
--   width:        not used here; read by make_multi_hash.
local sip_hash_config = {
   size={required=true}, stride={}, key={default=false},
   c={default=2}, d={default=4}, as_specified={default=false}, width={default=1}
}

-- Generate a SipHash function for fixed-size inputs.
--
-- assembler: backend constructor (X86_64, SSE, AVX2, Simulator, ...);
--            it is called with the stride and must return a table of
--            operations plus an assemble(name, gen) method.
-- opts:      table of options as described by sip_hash_config.
-- Returns whatever the backend's assemble method returns.
local function make_sip_hash(assembler, opts)
   -- The program generator.  It is local on purpose: declared as a
   -- plain "function siphash", every call would overwrite a
   -- module-level global of that name.
   local function siphash(asm)
      -- Arguments:
      -- rdi: packed keys as pointer
      -- rsi: for parallel implementations, an output uint32_t[4]

      -- Working registers:
      -- Registers 0-3 map to SipHash variables v0-v3
      -- Registers 6 and 7 used as scratch.

      local function sipround()
         asm.add(0, 1)
         asm.rol(1, 13)
         asm.xor(1, 0)
         asm.rol(0, 32)
         asm.add(2, 3)
         asm.rol(3, 16)
         asm.xor(3, 2)
         asm.add(0, 3)
         asm.rol(3, 21)
         asm.xor(3, 0)
         asm.add(2, 1)
         asm.rol(1, 17)
         asm.xor(1, 2)
         asm.rol(2, 32)
      end

      -- Absorb one 64-bit word held in register `input`.
      local function process(input)
         asm.xor(3, input)
         for i=1,opts.c do sipround() end
         asm.xor(0, input)
      end

      -- Initialization phase.
      asm.init(opts.key)

      -- Compression phase: one iteration per complete 8-byte word.
      for i=1,math.floor(opts.size/8) do
         asm.load_u64_and_advance(6)
         process(6)
      end
      -- Load tail word and process it.
      if opts.as_specified then
         -- Size byte in the top 8 bits, remaining input bytes below it
         -- with the first byte lowest.
         asm.load_imm8(6, bit.band(opts.size, 0xff))
         asm.shl(6, 56)
         for i=1,opts.size%8 do
            asm.load_u8_and_advance(7)
            if i > 1 then asm.shl(7, (i - 1) * 8) end
            asm.ior(6, 7)
         end
         process(6)
      elseif opts.size%8 ~= 0 then
         -- Fixed-size simplification: no need to add in size byte, we
         -- can use different byte orders if it's more convenient, and
         -- we don't have to do anything at all if the size is a
         -- multiple of 8.
         if opts.size%8 >= 4 then
            asm.load_u32_and_advance(6)
         else
            -- No 32-bit part: start from a zeroed accumulator.
            asm.xor(6, 6)
         end
         if opts.size%4 >= 2 then
            asm.load_u16_and_advance(7)
            asm.shl(6, 16)
            asm.ior(6, 7)
         end
         if opts.size%2 ~= 0 then
            asm.load_u8_and_advance(7)
            asm.shl(6, 8)
            asm.ior(6, 7)
         end
         process(6)
      end

      -- Finalization.
      asm.load_imm8(6, 0xff)
      asm.xor(2, 6)
      for i=1,opts.d do sipround() end
      asm.xor(0, 1)
      asm.xor(2, 3)
      asm.xor(0, 2)
      asm.ret(0)
   end

   -- siphash reads opts as an upvalue, so it sees the parsed options
   -- by the time assemble runs it.
   opts = lib.parse(opts, sip_hash_config)
   if not opts.stride then opts.stride = opts.size end
   local asm = assembler(opts.stride)
   return asm:assemble("siphash_"..opts.c.."_"..opts.d, siphash)
end

-- A special implementation to hash immediate values; requires our
-- fixed-size simplification.
--
-- The returned function takes the 64-bit value itself as its argument
-- (prototype uint32_t (*)(uint64_t)) rather than a pointer to it.  The
-- caller's opts table is copied, not modified; as_specified must not
-- be set.
function make_u64_hash(opts)
   local u64_opts = lib.deepcopy(opts)
   u64_opts.size = 8
   assert(not u64_opts.as_specified)

   -- Scalar backend variant whose 64-bit "load" takes the value from
   -- the first argument register instead of from memory.
   local function ImmX86_64()
      local asm = X86_64()
      asm.load_u64_and_advance = function(dst)
         asm.copy_argument(dst, 1)
      end
      asm.finish = function(name, Dst)
         return finish(name.."_u64",
                       ffi.typeof("uint32_t (*)(uint64_t)"),
                       Dst)
      end
      -- There is no memory input, so the narrower loads must never be
      -- reached; make them raise if they are.
      local narrow_loads = { "load_u32_and_advance",
                             "load_u16_and_advance",
                             "load_u8_and_advance" }
      for _, op in ipairs(narrow_loads) do asm[op] = error end
      return asm
   end

   return make_sip_hash(ImmX86_64, u64_opts)
end

-- Build a hash function for one input at a time using the scalar
-- x86-64 backend.
local function make_hash1(opts)
   local hash = make_sip_hash(X86_64, opts)
   return hash
end
-- Build a function hash(ptr, result) that hashes two inputs spaced
-- stride bytes apart and writes the two hashes to result[0..1].  Uses
-- the SSE backend when have_sse2 is set; otherwise calls the scalar
-- hash once per input.
local function make_hash2(opts)
   if not have_sse2 then
      local scalar = make_hash1(opts)
      local lane_bytes = opts.stride or opts.size
      return function(ptr, result)
         local base = ffi.cast('uint8_t*', ptr)
         local out = ffi.cast('uint32_t*', result)
         out[0] = scalar(base)
         out[1] = scalar(base + lane_bytes)
      end
   end
   return make_sip_hash(SSE, opts)
end
-- Build a function hash(ptr, result) that hashes four inputs spaced
-- stride bytes apart and writes the four hashes to result[0..3].  Uses
-- the AVX2 backend when have_avx2 is set; otherwise calls the two-wide
-- hash on each pair of inputs.
local function make_hash4(opts)
   if not have_avx2 then
      local pair_hash = make_hash2(opts)
      local lane_bytes = opts.stride or opts.size
      return function(ptr, result)
         local base = ffi.cast('uint8_t*', ptr)
         local out = ffi.cast('uint32_t*', result)
         pair_hash(base, out)
         pair_hash(base + lane_bytes*2, out + 2)
      end
   end
   return make_sip_hash(AVX2, opts)
end
-- Build a hash function backed by the portable Simulator, for use as
-- the reference that the generated code is checked against.
local function make_reference_sip_hash(opts)
   local reference = make_sip_hash(Simulator, opts)
   return reference
end

-- Exported name for the single-input hash constructor (the scalar
-- backend).
make_hash = make_hash1

-- Build a function hash(input, output) that hashes opts.width inputs
-- (default 1) laid out opts.stride bytes apart (default opts.size) and
-- writes one uint32_t hash per input to output.
--
-- The work is split into four-wide, two-wide and single hashes.  All
-- of them must use the same key, otherwise inputs in different
-- positions of one batch would be hashed under different keys.
function make_multi_hash(opts)
   if not opts.key then
      -- No key given: each make_hashN call below would otherwise draw
      -- its own random key.  Pick one here and share it, working on a
      -- shallow copy so the caller's table is left untouched.
      local keyed = {}
      for k, v in pairs(opts) do keyed[k] = v end
      keyed.key = random_sip_hash_key()
      opts = keyed
   end

   local hash1,hash2,hash4 = make_hash1(opts),make_hash2(opts),make_hash4(opts)
   local width = opts.width or 1
   local stride = (opts.stride or opts.size)
   if width == 1 then
      return function(input, output) output[0] = hash1(input) end
   end
   if width == 2 then return hash2 end
   if width == 4 then return hash4 end

   -- Multiple of four: four-wide batches only.
   if width % 4 == 0 then
      return function(input, output)
         input = ffi.cast('uint8_t*', input)
         output = ffi.cast('uint32_t*', output)
         for i=1,width,4 do
            hash4(input, output)
            input = input + stride*4
            output = output + 4
         end
      end
   end

   -- General case: as many four-wide batches as fit, then one two-wide
   -- batch and/or one single hash for the remainder.
   return function(input, output)
      input = ffi.cast('uint8_t*', input)
      output = ffi.cast('uint32_t*', output)
      for i=1,bit.rshift(width, 2) do
         hash4(input, output)
         input = input + stride*4
         output = output + 4
      end
      if bit.band(width, 2) ~= 0 then
         hash2(input, output)
         input = input + stride*2
         output = output + 2
      end
      if bit.band(width, 1) ~= 0 then
         output[0] = hash1(input)
      end
   end
end

-- Check the generated hash functions against the reference Simulator
-- for every input size from 0 to 32, every c in 0..2, every d in 0..4,
-- with and without as_specified, at widths 1, 2, 4 and 8; then check
-- that the immediate-value hash agrees with the in-memory hash.
-- Raises an error on the first mismatch.
function selftest()
   -- Deep copy of `base` with the entries of `overrides` written on
   -- top.
   local function merged(base, overrides)
      local out = lib.deepcopy(base)
      for key, value in pairs(overrides) do out[key] = value end
      return out
   end

   -- Write progress output immediately.
   local function progress(text)
      io.stdout:write(text)
      io.stdout:flush()
   end

   -- Verify one option set: scalar hash first, then each width.
   local function check_config(opts)
      local size = opts.size
      local reference_hash = make_reference_sip_hash(opts)
      local scalar_hash = make_hash(opts)
      local sample = lib.random_bytes(size)

      local expected = reference_hash(sample)

      local function expect_reference(actual, label)
         if actual == expected then return end
         for k,v in pairs(opts) do print(k,v) end
         error('got '..label..'='..actual..'; expected '..expected)
      end

      expect_reference(scalar_hash(sample), 'scalar')

      -- For each width, put the sample in one slot at a time and
      -- leave all the other slots zeroed.
      local expected_zero = reference_hash(ffi.new("uint8_t[?]", size))
      for _,width in ipairs({1,2,4,8}) do
         local inputs = ffi.new("uint8_t[?]", size*width)
         local hashes = ffi.new("uint32_t[?]", width)
         local multi_hash = make_multi_hash(merged(opts, {width=width}))
         for slot=0,width-1 do
            ffi.fill(inputs, size*width)
            ffi.copy(inputs+slot*size, sample, size)
            multi_hash(inputs, hashes)
            for elt=0,width-1 do
               if elt == slot then
                  expect_reference(hashes[slot],
                                   string.format('x%d[%d]', width, elt))
               elseif hashes[elt] ~= expected_zero then
                  error('got hash(0)='..hashes[elt]..'; expected '..expected_zero)
               end
            end
         end
      end
   end

   progress("selftest: ")
   local opts = { key=random_sip_hash_key() }
   local spec_modes = {true, false}
   for size=0,32 do
      opts.size = size
      progress(".")
      for c=0,2 do
         opts.c = c
         for d=0,4 do
            opts.d = d
            for _, as_specified in ipairs(spec_modes) do
               opts.as_specified = as_specified
               check_config(opts)
            end
         end
      end
   end

   -- We need a hash function for immediates as well; test that here.
   opts.size, opts.as_specified = 8, false
   local val = ffi.new('uint64_t[1]', 0x12345678ULL)
   for c=0,2 do
      opts.c = c
      for d=0,4 do
         progress(".")
         opts.d = d
         val[0] = val[0] * 257
         local hash_mem = make_hash(opts)
         local hash_u64 = make_u64_hash(opts)
         assert(hash_mem(val) == hash_u64(val[0]))
      end
   end

   print("\nselftest ok")
end
